{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# WGAN-GP生成二次元原理与解析"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 查看代码时可以查看一下这篇论文[论文地址](https://arxiv.org/pdf/1704.00028.pdf)，论文讲的就是wgan-gp，如果你没有gan的基础可以先了解了解gan的原理，实际上gan发展到wgan-gp在代码层面其实改变的并不多，后面我会对很多gan进行开源，望大家支持，谢谢！"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## WGAN-GP较前面的模型有哪些改进呢？\n",
    "1. 这次的模型，我们依然使用了`DCGAN`的网络结构，因为`WGAN-GP`的学习重点不在网络上\n",
    "\n",
    "2. WGAN 论文发现了 JS 散度导致 GAN 训练不稳定的问题，并引入了一种新的分布距离度量方法：Wasserstein 距离，也叫推土机距离(Earth-Mover Distance，简称 EM 距离)，它表示了从一个分布变换到另一个分布的最小代价\n",
    "\n",
    "3. 由于前WGAN并没有真正的实现`1-Lipschitz`，只有对任意输入x梯度都小于或等于1的时候，则该函数才是` 1-Lipschitz function`\n",
    "\n",
    "##  最后得到本篇精髓\n",
    "![](https://ai-studio-static-online.cdn.bcebos.com/1ffe098b66e7463f9c938b343f3d3f5685213396c9c448d8bee984746b51dcef)\n",
    "### **也就是想尽办法的让判别器认为真实的图片得到更高的分数所以需要加上负号，让生成的图片得到更低的分数，最后加上一个梯度的惩罚项，对于惩罚项的解释(在惩罚项中希望,如果越是满足`1-Lipschitz function`，惩罚就越少。事实证明这样做的效果是非常好的),现在大量GAN都能看到WGAN-GP的身影，所以它非常的重要**\n",
    "![](https://ai-studio-static-online.cdn.bcebos.com/ca021e41c11242eaafd52f8a5b2b13ca90a88b2a4f044bb59bc0720e84d85d25)\n",
    "### **其中有个超参数（1）表示超参数  本文设置为10**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 导入包paddle几乎都这么导\n",
    "import paddle.fluid as fluid\n",
    "import paddle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "###  生成器的实现，网络可以自行实现我这个为基本版随意写的，主要是领略GAN的思想，也是为了方便学习\n",
    "1. paddle中的静态图如果需要求导必须给定参数的名称，所以必须给每个参数命名一个独一无二的名字此处使用到了`fluid.unique_name.guard`用于生成参数名\n",
    "2. 生成器使用到了逆转卷积`conv2d_transpose`顾名思义是一种逆转的卷积和卷积正好相反，使用到了卷积这就是`DCGAN`与`传统GAN`的区别\n",
    "3. `conv2d_transpose`的paddling为`SAME`的时候有个非常简单的计算公式，图的大小 * 步长，为了方便我直接用的步长为2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 网络结构\n",
    "* 1. 使用卷积网络进行下采样\n",
    "* ![](https://ai-studio-static-online.cdn.bcebos.com/0cb758271fde40b68d51642452f8b18f6b343dd78b8c4fc09c031c0d8db85a09)\n",
    "* 2. 使用反卷积进行上采样\n",
    "* ![](https://ai-studio-static-online.cdn.bcebos.com/addd8ef2fd7b402e9d56ea35d98b0666fa77f74feca4438f9af3d38421d01081)\n",
    "* 取消所有 pooling 层。G 网络中使用微步幅度卷积（fractionally strided convolution）\n",
    "* 代替 pooling 层，D 网络中使用步幅卷积（strided convolution）代替 pooling 层。 ·在 D 和 G 中均使用 batch normalization\n",
    "* 去掉 FC 层，使网络变为全卷积网络\n",
    "* G 网络中使用 ReLU 作为激活函数，最后一层使用 tanh\n",
    "* D 网络中使用 LeakyReLU 作为激活函数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 于是按照给定的网络结构可以写出适合自己图片大小的网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 3*96*96\n",
    "def Generator(x, name='G'):\n",
    "    # (-1,100,1,1)\n",
    "    def deconv(x, num_filters, filter_size=5,padding='SAME', stride=2, act='relu'):\n",
    "            x = fluid.layers.conv2d_transpose(x, num_filters=num_filters, filter_size=filter_size, stride=stride, padding=padding)\n",
    "            x = fluid.layers.batch_norm(x, momentum=0.8)\n",
    "            if act=='relu':\n",
    "                x = fluid.layers.relu(x)\n",
    "            elif act=='tanh':\n",
    "                x = fluid.layers.tanh(x)\n",
    "            return x\n",
    "    with fluid.unique_name.guard(name+'/'):\n",
    "        x = fluid.layers.reshape(x, (-1, 100, 1, 1))\n",
    "        x = deconv(x, num_filters=512, filter_size=6, stride=1, padding='VALID') # 6\n",
    "        x = deconv(x, num_filters=256) # 12\n",
    "        x = deconv(x, num_filters=128) #24\n",
    "        x = deconv(x, num_filters=64) #48\n",
    "        x = deconv(x, num_filters=3, act='tanh') #96\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 判别器与生成器差不多\r\n",
    "1. 就是将生成器反过来进行运算，最后使用全连接输出1维度的数即可"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def Discriminator(x, name='D'):\n",
    "    def conv(x, num_filters, momentum=0.8):\n",
    "            x = fluid.layers.conv2d(x, num_filters=num_filters, filter_size=5, stride=2, padding='SAME')\n",
    "            x = fluid.layers.batch_norm(x, momentum=momentum)\n",
    "            x = fluid.layers.leaky_relu(x, alpha=0.2)\n",
    "            x = fluid.layers.dropout(x, dropout_prob=0.25)\n",
    "            return x\n",
    "    with fluid.unique_name.guard(name+'/'):\n",
    "        # (3, 96, 96)\n",
    "        x = conv(x, num_filters=64) # 48\n",
    "        x = conv(x, num_filters=128) # 24\n",
    "        x = conv(x, num_filters=256) # 12\n",
    "        x = conv(x, num_filters=512) # 6\n",
    "        x = fluid.layers.pool2d(x, pool_type='avg', global_pooling=True)\n",
    "        x = fluid.layers.flatten(x)\n",
    "        x = fluid.layers.fc(x, size=1)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 静态图必须进行的步揍不多解释"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "d_program_1 = fluid.Program()\n",
    "d_program_2 = fluid.Program()\n",
    "g_program = fluid.Program()\n",
    "# 每批次数量\n",
    "batch_size = 128"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 判别器优化器实现\n",
    "1. 生成器的超参数有噪声的数量这里固定为100\n",
    "2. 通过输入的噪声生成图像(`Generator`)，并将生成的图片给判别器(`Discriminator`),我们给它一个低分\n",
    "3. 直接将真实的图片给判别器，并给它一个高分，由于优化器是梯度下降的方向，我们想要得到上升那么就给损失一个负号用于得到高分\n",
    "4. 给定真实图片和生成图片，完成梯度惩罚项GP, 并让他减小损失\n",
    "5. 训练需要固定生成器，让判别器学习\n",
    "6. 优化器 beta1=0, beta2=0.9\n",
    "7. 关于GP梯度惩罚项的实现[参考](https://github.com/PaddlePaddle/models/blob/release/1.7/PaddleCV/gan/trainer/STGAN.py第150行的函数)\n",
    "8. c为梯度惩罚项的超参数默认为 10.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "z_num = 100\n",
    "c = 10.0\n",
    "with fluid.program_guard(d_program_1):\n",
    "    # 传入参数用fluid.data\n",
    "    fake_z_1 = fluid.data(name='fake_z', dtype='float32', shape=(None, z_num))\n",
    "    # 通过生成器生成图片\n",
    "    img_fake_1 = Generator(fake_z_1)\n",
    "    # 判别器判断好坏\n",
    "    fake_ret_1 = Discriminator(img_fake_1)\n",
    "    # 判别器损失\n",
    "    loss_1 = fluid.layers.mean(fake_ret_1)\n",
    "\n",
    "    # 将名字为开头为D的参数取出来进行优化，也就是判别器的简写\n",
    "    parameter_list = []\n",
    "    for var in d_program_1.list_vars():\n",
    "        if fluid.io.is_parameter(var) and var.name.startswith('D'):\n",
    "            parameter_list.append(var.name)\n",
    "    \n",
    "    # 优化一下，优化原理为将判别器错误的图片尽可能的判别为小的\n",
    "    fluid.optimizer.AdamOptimizer(learning_rate=1e-4,beta1=0, beta2=0.9).minimize(loss_1, parameter_list=parameter_list)\n",
    "\n",
    "    # 用fluid.data传入真实的图片\n",
    "    img_real_1 = fluid.data(name='img_real', dtype='float32', shape=(None, 3, 96, 96))\n",
    "    # 用判别真实的图片\n",
    "    real_ret_1 = Discriminator(img_real_1)\n",
    "    # 损失函数\n",
    "    loss_2 = -1 * fluid.layers.mean(real_ret_1)\n",
    "    # 为啥要加个负号呢？因为我们想让真实的图片分数尽可能的大，优化器是让损失减小，加个负号即是增大\n",
    "    fluid.optimizer.AdamOptimizer(learning_rate=1e-4,beta1=0, beta2=0.9).minimize(loss_2, parameter_list=parameter_list)\n",
    "    \n",
    "    # 这里的惩罚项gp是copy于paddle的github\n",
    "    # https://github.com/PaddlePaddle/models/blob/release/1.7/PaddleCV/gan/trainer/STGAN.py第150行的函数\n",
    "    def _interpolate(a, b=None):\n",
    "        beta = fluid.layers.uniform_random_batch_size_like(\n",
    "            input=a, shape=a.shape, min=0.0, max=1.0)\n",
    "                \n",
    "        mean = fluid.layers.reduce_mean(\n",
    "            a, dim=list(range(len(a.shape))), keep_dim=True)\n",
    "        input_sub_mean = fluid.layers.elementwise_sub(a, mean, axis=0)\n",
    "        var = fluid.layers.reduce_mean(\n",
    "            fluid.layers.square(input_sub_mean),\n",
    "            dim=list(range(len(a.shape))),\n",
    "            keep_dim=True)\n",
    "        b = beta * fluid.layers.sqrt(var) * 0.5 + a\n",
    "        shape = [a.shape[0]] \n",
    "        alpha = fluid.layers.uniform_random_batch_size_like(\n",
    "            input=a, shape=shape, min=0.0, max=1.0)\n",
    "\n",
    "        inner = fluid.layers.elementwise_mul((b-a), alpha, axis=0) + a\n",
    "        return inner\n",
    "\n",
    "    x = _interpolate(img_real_1, img_fake_1)\n",
    "    pred = Discriminator(x)\n",
    "    \n",
    "    vars = []\n",
    "    for var in fluid.default_main_program().list_vars():\n",
    "        if fluid.io.is_parameter(var) and var.name.startswith(\"D\"):\n",
    "            vars.append(var.name)\n",
    "    grad = fluid.gradients(pred, x, no_grad_set=vars)[0]\n",
    "    grad_shape = grad.shape\n",
    "    grad = fluid.layers.reshape(\n",
    "        grad, [-1, grad_shape[1] * grad_shape[2] * grad_shape[3]])\n",
    "    epsilon = 1e-16\n",
    "    norm = fluid.layers.sqrt(\n",
    "        fluid.layers.reduce_sum(\n",
    "            fluid.layers.square(grad), dim=1) + epsilon)\n",
    "    gp = fluid.layers.reduce_mean(fluid.layers.square(norm - 1.0)) * c\n",
    "    d_loss = loss_1 + loss_2 + gp\n",
    "    # clip = fluid.clip.GradientClipByValue(min=CLIP[0], max=CLIP[1])\n",
    "    # 优化即可\n",
    "    fluid.optimizer.AdamOptimizer(learning_rate=1e-4,beta1=0, beta2=0.9).minimize(gp, parameter_list=parameter_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 生成器优化的实现\n",
    "1. 输入噪声给生成器，因为我们想让生成器生成一张真实的图片，所以我们需要固定判别器，让生成器生成一张趋于真实的图片，所以要让他向分数高的方向走，所以也要加个负号"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "with fluid.program_guard(g_program):\n",
    "    fake_z_2 = fluid.data(name='fake_z', dtype='float32', shape=(None, z_num))\n",
    "    \n",
    "    img_fake_2 = Generator(fake_z_2)\n",
    "    # 克隆一下，并固定参数，用于预测\n",
    "    test_program = g_program.clone(for_test=True)\n",
    "    fake_ret_2 = Discriminator(img_fake_2)\n",
    "    \n",
    "\n",
    "    # loss = fluid.layers.sigmoid_cross_entropy_with_logits(x=fake_ret, label=fluid.layers.ones_like(real_ret))\n",
    "    avg_loss_2 = -1 * fluid.layers.mean(fake_ret_2)\n",
    "    \n",
    "    parameter_list = []\n",
    "    for var in d_program_1.list_vars():\n",
    "        if fluid.io.is_parameter(var) and var.name.startswith('G'):\n",
    "            parameter_list.append(var.name)\n",
    "    \n",
    "    # 这里是让分数尽可能的大，所以加上负号\n",
    "    fluid.optimizer.AdamOptimizer(learning_rate=1e-4, beta1=0, beta2=0.9).minimize(avg_loss_2, parameter_list=parameter_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 准备训练前的配置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "use_gpu = True\n",
    "if use_gpu:\n",
    "    place = fluid.CUDAPlace(0)\n",
    "else:\n",
    "    place = fluid.CPUPlace()\n",
    "\n",
    "exe = fluid.executor.Executor(place=place)\n",
    "exe.run(fluid.default_startup_program())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 数据读入函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "!unzip -oq /home/aistudio/data/data17962/二次元人物头像.zip\n",
    "fullpath = '/home/aistudio/faces'\n",
    "def get_batch():\n",
    "    filenames = os.listdir(fullpath)\n",
    "    random.shuffle(filenames)\n",
    "    img_list = []\n",
    "    label_list = []\n",
    "    for i, filename in enumerate(filenames):\n",
    "        fullname = os.path.join(fullpath, filename)\n",
    "        img = cv2.imread(fullname, cv2.IMREAD_COLOR)\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "        img = img / 255.*2 - 1\n",
    "        img = np.transpose(img, (2,0,1))\n",
    "        img_list.append(img)\n",
    "        label_list.append(1.)\n",
    "        if (i+1) % batch_size == 0:\n",
    "            yield np.array(img_list).astype('float32'), np.array(label_list)\n",
    "            img_list = []\n",
    "            label_list = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 快照继续训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "try:\r\n",
    "    fluid.io.load_params(exe, dirname='/home/aistudio/out_params', main_program=fluid.default_startup_program())\r\n",
    "except:\r\n",
    "    pass\r\n",
    "def get_cur():\r\n",
    "    try:\r\n",
    "        img_names = np.array([i.strip('.jpg').split('_') for i in os.listdir('/home/aistudio/out_params/') if '.jpg' in i]).astype('int')\r\n",
    "    except:\r\n",
    "        return [0, 0]\r\n",
    "    if len(img_names) == 0:\r\n",
    "        return [0, 0]\r\n",
    "    return np.max(img_names, axis=0)\r\n",
    "cur_process = get_cur()\r\n",
    "epoch_pro = cur_process[0]\r\n",
    "step_pro = cur_process[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 训练开始\n",
    "1. 由于判别器学习得快一点所以需要，为了保持同步我们可以将生成器进行多次得训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '/home/aistudio/faces'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_3953/3177579948.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mtrain_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m     \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m         \u001b[0mepoch\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mepoch_pro\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m         \u001b[0mstep\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mstep_pro\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_3953/2752872269.py\u001b[0m in \u001b[0;36mget_batch\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mfullpath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'/home/aistudio/faces'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m     \u001b[0mfilenames\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfullpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m     \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilenames\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0mimg_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/home/aistudio/faces'"
     ]
    }
   ],
   "source": [
    "epochs = 20000\n",
    "train_d = 2\n",
    "for epoch in range(epochs):\n",
    "    for step,(x, y) in enumerate(get_batch()):\n",
    "        epoch += epoch_pro\n",
    "        step += step_pro\n",
    "        fake_z = np.random.uniform(size=(x.shape[0], z_num), low=-1, high=1).astype('float32')\n",
    "        [loss_1_, loss_2_, gp_] = exe.run(program=d_program_1, \n",
    "                                feed={'fake_z':fake_z,'img_real':x}, \n",
    "                                fetch_list=[loss_1, loss_2, gp])\n",
    "\n",
    "\n",
    "        # 生成器训练\n",
    "        for _ in range(train_d):\n",
    "            fake_z = np.random.uniform(size=(x.shape[0], z_num), low=-1, high=1).astype('float32')\n",
    "            [g_loss] = exe.run(program=g_program, \n",
    "                    feed={'fake_z':fake_z}, \n",
    "                    fetch_list=[avg_loss_2])\n",
    "        \n",
    "        # 100次进行预测一次\n",
    "        if step % 100 == 0:\n",
    "            print(loss_1_, loss_2_, gp_, g_loss)\n",
    "            print('[Training] epoch:{} step:{} d_loss:{} g_loss:{}'.format(epoch, step, loss_1_+loss_2_+gp_, g_loss))\n",
    "            # 快照\n",
    "            fluid.io.save_params(exe, dirname='/home/aistudio/out_params', main_program=fluid.default_startup_program())    \n",
    "            fake_z = np.random.uniform(size=(z_num, z_num), low=-1, high=1).astype('float32')\n",
    "            [pre_im] = exe.run(program=test_program, \n",
    "                    feed={'fake_z':fake_z}, \n",
    "                    fetch_list=[img_fake_2])\n",
    "            pre_im = (np.transpose(pre_im, (0, 2, 3, 1))+1) / 2\n",
    "            \n",
    "            # 准备一个画布用于存放生成的图片\n",
    "            images = np.zeros((960, 960, 3))\n",
    "            for h in range(10):\n",
    "                for w in range(10):\n",
    "                    images[h*96:(h+1)*96, w*96:(w+1)*96] = pre_im[(h*10)+w]\n",
    "            plt.imsave('/home/aistudio/out_params/{}_{}.jpg'.format(epoch, step), images)\n",
    "            plt.imshow(images, cmap='gray')\n",
    "            plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\r\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 最后可以将我们得图片生成过程做成视频，感受GAN带来得神奇"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "video_dir = './a.mp4'\r\n",
    "fps = 5\r\n",
    "img_size = (960, 960)\r\n",
    "\r\n",
    "fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')\r\n",
    "video_writer = cv2.VideoWriter(video_dir, fourcc, fps, img_size)\r\n",
    "file_list = os.listdir('out_params')\r\n",
    "file_list_ = []\r\n",
    "for i in file_list:\r\n",
    "    if '.jpg' in i:\r\n",
    "        epoch, step = i.strip('.jpg').split('_')\r\n",
    "        file_list_.append(int(epoch) * 1000 + int(step))\r\n",
    "file_list_.sort()\r\n",
    "for i in file_list_:\r\n",
    "    img_name = os.path.join('out_params', '%s_%s' % (i // 1000, i % 1000)+\".jpg\")\r\n",
    "    img_array = cv2.imread(img_name, cv2.IMREAD_COLOR)\r\n",
    "    video_writer.write(img_array)\r\n",
    "video_writer.release()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 简单总结一下吧\n",
    "1. 也是第一次在paddle上写完整的网络，通过查文档问老师提issue，也是在大家的帮助下终于完成了项目\n",
    "2. 就本次实验来说，让我更加了解WGAN-GP的实现,也为是实现其他GAN模型奠定了基础\n",
    "3. 但是实验过程中我发现，生成效果并不是非常好，我猜想这和网络有直接关系，希望可以阅读跟多的文章借鉴跟多文章来生成更高清的二次元\n",
    "# 介绍自己\n",
    "* 喜欢编程，能全栈\n",
    "* 选择Python，更热爱AI\n",
    "*  来AI Studio互粉吧~等你哦~ [主页](https://aistudio.baidu.com/aistudio/personalcenter/thirdview/340127)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
